{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "# Model Averaging\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/addons/tutorials/average_optimizers_callback\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/average_optimizers_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/addons/blob/master/docs/tutorials/average_optimizers_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "      <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/average_optimizers_callback.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This notebook demonstrates how to use Moving Average Optimizer along with the Model Average Checkpoint from tensorflow addons pagkage.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o2UNySlpXkbl"
      },
      "source": [
        "## Moving Averaging \n",
        "\n",
        "> The advantage of Moving Averaging is that they are less prone to rampant loss shifts or irregular data representation in the latest batch. It gives a smooothened and a more genral idea of the model training until some point.\n",
        "\n",
        "## Stocastic Averaging\n",
        "\n",
        "> Stocastic Weight Averaging converges to wider optimas. By doing so, it resembles geometric ensembeling. SWA is a simple method to improve model performance when used as a wrapper around other optimizers and averaging results from different points of trajectory of the inner optimizer.\n",
        "\n",
        "## Model Average Checkpoint \n",
        "\n",
        "> `callbacks.ModelCheckpoint` doesn't give you the option to save moving average weights in the middle of training, which is why Model Average Optimizers required a custom callback. Using the ```update_weights``` parameter, ```ModelAverageCheckpoint``` allows you to:\n",
        "1.   Assign the moving average weights to the model, and save them.\n",
        "2.   Keep the old non-averaged weights, but the saved model uses the average weights."
      ]
    },
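    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two averaging schemes described above concrete, here is a minimal NumPy sketch. It is not the TensorFlow Addons implementation: it assumes a made-up one-dimensional weight trajectory, an illustrative `decay` of 0.99, and averaging at every step. The names `trajectory`, `ema`, and `swa` are hypothetical and used only for this illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical 1-D weight trajectory: noisy iterates around an optimum at 1.0.\n",
        "rng = np.random.default_rng(0)\n",
        "trajectory = 1.0 + 0.1 * rng.standard_normal(200)\n",
        "\n",
        "# Exponential moving average: recent iterates receive more weight.\n",
        "decay = 0.99  # assumed value, chosen only for illustration\n",
        "ema = trajectory[0]\n",
        "for w in trajectory[1:]:\n",
        "    ema = decay * ema + (1.0 - decay) * w\n",
        "\n",
        "# Equal-weight running average, the flavour of averaging that SWA is based on.\n",
        "swa = trajectory[0]\n",
        "for n, w in enumerate(trajectory[1:], start=1):\n",
        "    swa = (swa * n + w) / (n + 1)\n",
        "\n",
        "print('last iterate        :', trajectory[-1])\n",
        "print('moving average      :', ema)\n",
        "print('equal-weight average:', swa)\n",
        "```\n",
        "\n",
        "Both averages smooth out the per-step noise that the last iterate still carries. The moving average favours recent iterates, while the equal-weight average treats every iterate the same."
      ]
    },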
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sXEOqj5cIgyW"
      },
      "outputs": [],
      "source": [
        "!pip install -U tensorflow-addons"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_addons as tfa"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4hnJ2rDpI38-"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import os"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Iox_HZNNYLEB"
      },
      "source": [
        "## Build Model "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KtylpxOmceaC"
      },
      "outputs": [],
      "source": [
        "def create_model(opt):\n",
        "    model = tf.keras.models.Sequential([\n",
        "        tf.keras.layers.Flatten(),                         \n",
        "        tf.keras.layers.Dense(64, activation='relu'),\n",
        "        tf.keras.layers.Dense(64, activation='relu'),\n",
        "        tf.keras.layers.Dense(10, activation='softmax')\n",
        "    ])\n",
        "\n",
        "    model.compile(optimizer=opt,\n",
        "                    loss='sparse_categorical_crossentropy',\n",
        "                    metrics=['accuracy'])\n",
        "\n",
        "    return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pwdM2pl3RSPb"
      },
      "source": [
        "## Prepare Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mMOeXVmbdilM"
      },
      "outputs": [],
      "source": [
        "#Load Fashion MNIST dataset\n",
        "train, test = tf.keras.datasets.fashion_mnist.load_data()\n",
        "\n",
        "images, labels = train\n",
        "images = images/255.0\n",
        "labels = labels.astype(np.int32)\n",
        "\n",
        "fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))\n",
        "fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)\n",
        "\n",
        "test_images, test_labels = test"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iEbhI_eajpJe"
      },
      "source": [
        "We will be comparing three optimizers here:\n",
        "\n",
        "*   Unwrapped SGD\n",
        "*   SGD with Moving Average\n",
        "*   SGD with Stochastic Weight Averaging\n",
        "\n",
        "And see how they perform with the same model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Q76K1fNk7Va"
      },
      "outputs": [],
      "source": [
        "#Optimizers \n",
        "sgd = tf.keras.optimizers.SGD(0.01)\n",
        "moving_avg_sgd = tfa.optimizers.MovingAverage(sgd)\n",
        "stocastic_avg_sgd = tfa.optimizers.SWA(sgd)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nXlMX4p9qHwg"
      },
      "source": [
        "Both ```MovingAverage``` and ```StocasticAverage``` optimers use ```ModelAverageCheckpoint```."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SnvZjt34qEHY"
      },
      "outputs": [],
      "source": [
        "#Callback \n",
        "checkpoint_path = \"./training/cp-{epoch:04d}.ckpt\"\n",
        "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
        "\n",
        "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,\n",
        "                                                 save_weights_only=True,\n",
        "                                                 verbose=1)\n",
        "avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir, \n",
        "                                                    update_weights=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uabQmjMtRtzs"
      },
      "source": [
        "## Train Model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPmifETHmPix"
      },
      "source": [
        "### Vanilla SGD Optimizer "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xy8W4LYppadJ"
      },
      "outputs": [],
      "source": [
        "#Build Model\n",
        "model = create_model(sgd)\n",
        "\n",
        "#Train the network\n",
        "model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uU2iQ6HAZ6-E"
      },
      "outputs": [],
      "source": [
        "#Evalute results\n",
        "model.load_weights(checkpoint_dir)\n",
        "loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)\n",
        "print(\"Loss :\", loss)\n",
        "print(\"Accuracy :\", accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lAvhD4unmc6W"
      },
      "source": [
        "### Moving Average SGD"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "--NIjBp-mhVb"
      },
      "outputs": [],
      "source": [
        "#Build Model\n",
        "model = create_model(moving_avg_sgd)\n",
        "\n",
        "#Train the network\n",
        "model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zRAym9EBmnW9"
      },
      "outputs": [],
      "source": [
        "#Evalute results\n",
        "model.load_weights(checkpoint_dir)\n",
        "loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)\n",
        "print(\"Loss :\", loss)\n",
        "print(\"Accuracy :\", accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K98lbU07m_Bk"
      },
      "source": [
        "### Stocastic Weight Average SGD "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ia7ALKefnXWQ"
      },
      "outputs": [],
      "source": [
        "#Build Model\n",
        "model = create_model(stocastic_avg_sgd)\n",
        "\n",
        "#Train the network\n",
        "model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EOT2E9NBoeHI"
      },
      "outputs": [],
      "source": [
        "#Evalute results\n",
        "model.load_weights(checkpoint_dir)\n",
        "loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)\n",
        "print(\"Loss :\", loss)\n",
        "print(\"Accuracy :\", accuracy)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "average_optimizers_callback.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
